{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Callbacks\n",
    "\n",
    "> Callbacks which work with a learner's data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CollectDataCallback(Callback):\n",
    "    \"Record each batch as a detached `(xb, yb, pred, loss)` tuple in `self.data`. Mainly for testing\"\n",
    "    def before_fit(self):\n",
    "        # Start every `fit` call with an empty collection\n",
    "        self.data = L()\n",
    "\n",
    "    def after_batch(self):\n",
    "        batch_record = (self.xb, self.yb, self.pred, self.loss)\n",
    "        self.data.append(self.learn.to_detach(batch_record))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CudaCallback(Callback):\n",
    "    \"Move the model and every batch to `device` (`default_device()` when `device` is None)\"\n",
    "    def __init__(self, device=None): self.device = ifnone(device, default_device())\n",
    "    def before_batch(self):\n",
    "        # Pass `self.device` explicitly so the batch lands on the same device as the model\n",
    "        self.learn.xb,self.learn.yb = to_device(self.xb, self.device),to_device(self.yb, self.device)\n",
    "    def before_fit(self): self.model.to(self.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You don't normally need to use this Callback, because fastai's `DataLoader` will handle passing data to a device for you. However, if you already have a plain PyTorch DataLoader and can't change it for some reason, you can use this callback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 14.484489440917969, 14.21982192993164, '00:00']\n"
     ]
    }
   ],
   "source": [
    "#cuda\n",
    "# After one epoch with `CudaCallback`, the model parameters should live on the GPU\n",
    "learn = synth_learner(cbs=CudaCallback)\n",
    "learn.fit(1)\n",
    "test_eq(next(learn.model.parameters()).device.type, 'cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates()\n",
    "class WeightedDL(TfmdDL):\n",
    "    \"A `TfmdDL` that, when shuffling, draws item indices at random with probabilities proportional to `wgts`\"\n",
    "    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        # Fall back to uniform weights when the caller supplies none\n",
    "        if wgts is None: wgts = [1.]*len(dataset)\n",
    "        raw_wgts = array(wgts)\n",
    "        # Normalise so the weights sum to 1, as `np.random.choice(p=...)` expects\n",
    "        self.wgts = raw_wgts/raw_wgts.sum()\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.n==0: return []\n",
    "        if self.shuffle:\n",
    "            # Weighted draw of `n` indices out of `n` items\n",
    "            return list(np.random.choice(self.n, self.n, p=self.wgts))\n",
    "        # Without shuffling, defer to the parent's index order\n",
    "        return super().get_idxs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):\n",
    "    \"Build dataloaders of type `WeightedDL` via `self.dataloaders`, passing `wgts` to the first subset's loader only\"\n",
    "    # First loader receives `wgts`; each remaining subset gets empty extra kwargs\n",
    "    per_dl_kwargs = ({'wgts':wgts},) + tuple([{}] * (self.n_subsets-1))\n",
    "    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=per_dl_kwargs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Item i holds the value i and is given weight i (so item 0 has weight 0)\n",
    "n = 160\n",
    "dsets = Datasets(torch.arange(n).float())\n",
    "dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)\n",
    "# `CollectDataCallback` records every batch so the sampled items can be inspected afterwards\n",
    "learn = synth_learner(data=dls, cbs=CollectDataCallback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, nan, None, '00:01']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD5CAYAAAA+0W6bAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOb0lEQVR4nO3dfaxk9V3H8ffH5aEtYNh17+LKgxcb2khMupAroaIN5aECbXgwqSmxzZpitn8UA1q1S5vYGv+B2lL/MZitYDeWohseygbQFhElJM3Wu8jD4oLQdqVLt7sXUQH9wwJf/5iz7e3de3dm7527M7/yfiWTOed3zsx8djL3s2fOnDOTqkKS1J6fGHUASdLiWOCS1CgLXJIaZYFLUqMscElqlAUuSY06ot8KSd4EPAQc3a1/e1V9Kskq4G+ASWAX8OtV9Z8Hu6/Vq1fX5OTkEiNL0hvL9u3bX6iqibnj6XcceJIAx1TVK0mOBB4GrgF+DXixqq5PshFYWVUfP9h9TU1N1fT09KL/EZL0RpRke1VNzR3vuwulel7pZo/sLgVcBmzuxjcDlw8nqiRpEAPtA0+yIsmjwD7g/qraBpxQVXsAuus1y5ZSknSAgQq8ql6rqnXAScBZSX5h0AdIsiHJdJLpmZmZRcaUJM11SEehVNV/Af8IXATsTbIWoLvet8BtNlXVVFVNTUwcsA9ekrRIfQs8yUSS47vpNwMXAE8BW4H13WrrgbuXKaMkaR59DyME1gKbk6ygV/hbquqeJF8HtiS5CngOeP8y5pQkzdG3wKvqceCMecb/Azh/OUJJkvrzTExJapQFLkmNGmQfuCT9WJjceO/IHnvX9e8d+n26BS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWpU3wJPcnKSB5PsTPJkkmu68U8neT7Jo93lkuWPK0na74gB1nkV+FhVPZLkOGB7kvu7ZZ+vqs8uXzxJ0kL6FnhV7QH2dNMvJ9kJnLjcwSRJB3dI+8CTTAJnANu6oauTPJ7kliQrF7jNhiTTSaZnZmaWllaS9AMDF3iSY4E7gGur6iXgJuCtwDp6W+ifm+92VbWpqqaqampiYmLpiSVJwIAFnuRIeuV9a1XdCVBVe6vqtap6HfgCcNbyxZQkzTXIUSgBbgZ2VtWNs8bXzlrtCmDH8ONJkhYyyFEo5wAfAp5I8mg39gngyiTrgAJ2AR9ZhnySpAUMchTKw0DmWXTf8ONIkgblmZiS1KhBdqFI0lBNbrx31BF+LLgFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQof5VeYrS/kr7r+veO7LHVNrfAJalRFrgkNcoCl6RG9S3wJCcneTDJziRPJrmmG1+V5P4kz3TXK5c/riRpv0G2wF8FPlZVPw+cDXw0yenARuCBqjoNeKCblyQdJn0LvKr2VNUj3fTLwE7gROAyYHO32mbg8mXKKEmaxyHtA08yCZwBbANOqKo90Ct5YM0Ct9mQZDrJ9MzMzBLjSpL2G7jAkxwL3AFcW1UvDXq7qtpUVVNVNTUxMbGYjJKkeQxU4EmOpFfet1bVnd3w3iRru+VrgX3LE1GSNJ9BjkIJcDOws6punLVoK7C+m14P3D38eJKkhQxyKv05wIeAJ5I82o19Arge2JLkKuA54P3LklCSNK++BV5VDwNZYPH5w40jSRqUZ2JKUqP8NkLpDWqU38Co4XALXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKwwh1AH/gV2qDW+CS1CgLXJIaZYFLUqMscElq
lAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUqL4FnuSWJPuS7Jg19ukkzyd5tLtcsrwxJUlzDbIF/kXgonnGP19V67rLfcONJUnqp2+BV9VDwIuHIYsk6RAsZR/41Uke73axrBxaIknSQBb7q/Q3AX8MVHf9OeDD862YZAOwAeCUU05Z5MPpjWJy472jjiA1Y1Fb4FW1t6peq6rXgS8AZx1k3U1VNVVVUxMTE4vNKUmaY1EFnmTtrNkrgB0LrStJWh59d6EkuQ04F1idZDfwKeDcJOvo7ULZBXxk+SJKkubTt8Cr6sp5hm9ehiySpEPgmZiS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEYt9keNJQ2JP+SsxXILXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKwwjHmIeXSToYt8AlqVEWuCQ1ygKXpEb1LfAktyTZl2THrLFVSe5P8kx3vXJ5Y0qS5hpkC/yLwEVzxjYCD1TVacAD3bwk6TDqW+BV9RDw4pzhy4DN3fRm4PLhxpIk9bPYfeAnVNUegO56zUIrJtmQZDrJ9MzMzCIfTpI017J/iFlVm6pqqqqmJiYmlvvhJOkNY7EFvjfJWoDuet/wIkmSBrHYAt8KrO+m1wN3DyeOJGlQgxxGeBvwdeDtSXYnuQq4HrgwyTPAhd28JOkw6vtdKFV15QKLzh9yFknSIfBMTElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSo/r+oINgcuO9o44gSQdwC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSo5Z0Kn2SXcDLwGvAq1U1NYxQkqT+hvFdKO+uqheGcD+SpEPgLhRJatRSC7yAryXZnmTDfCsk2ZBkOsn0zMzMEh9OkrTfUgv8nKo6E7gY+GiSd81doao2VdVUVU1NTEws8eEkSfstqcCr6rvd9T7gLuCsYYSSJPW36AJPckyS4/ZPA+8BdgwrmCTp4JZyFMoJwF1J9t/Pl6vq74aSSpLU16ILvKq+BbxjiFkkSYfAwwglqVHN/KixPywsST/KLXBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjllTgSS5K8nSSZ5NsHFYoSVJ/iy7wJCuAPwMuBk4Hrkxy+rCCSZIObilb4GcBz1bVt6rq/4C/Bi4bTixJUj9LKfATge/Mmt/djUmSDoMjlnDbzDNWB6yUbAA2dLOvJHl6wPtfDbywyGyj0FpeaC+zeZdfa5mbyZsbfjC5mMw/O9/gUgp8N3DyrPmTgO/OXamqNgGbDvXOk0xX1dTi4x1ereWF9jKbd/m1lrm1vDDczEvZhfLPwGlJTk1yFPABYOswQkmS+lv0FnhVvZrkauCrwArglqp6cmjJJEkHtZRdKFTVfcB9Q8oy1yHvdhmx1vJCe5nNu/xay9xaXhhi5lQd8LmjJKkBnkovSY0aeYEnOTnJg0l2JnkyyTXd+Kok9yd5prteOeqssyVZkeRfktzTzY973uOT3J7kqe65fuc4Z07yO93rYUeS25K8adzyJrklyb4kO2aNLZgxyXXd1048neRXxyTvn3SviceT3JXk+HHJu1DmWct+L0klWT1rbOye4278t7tMTyb5zNDyVtVIL8Ba4Mxu+jjg3+idmv8ZYGM3vhG4YdRZ5+T+XeDLwD3d/Ljn3Qz8Vjd9FHD8uGamd0LYt4E3d/NbgN8c
t7zAu4AzgR2zxubN2L2mHwOOBk4FvgmsGIO87wGO6KZvGKe8C2Xuxk+mdwDFvwOrxyXzAs/xu4G/B47u5tcMK+/IXvwHeQLuBi4EngbWdmNrgadHnW1WxpOAB4DzZhX4OOf9ya4QM2d8LDPzw7N8V9H7oP2ermjGLi8wOeePdd6MwHXAdbPW+yrwzlHnnbPsCuDWccq7UGbgduAdwK5ZBT4Wmed5TWwBLphnvSXnHfkulNmSTAJnANuAE6pqD0B3vWaE0eb6U+APgNdnjY1z3p8DZoC/7Hb7/EWSYxjTzFX1PPBZ4DlgD/DfVfU1xjTvHAtlbOGrJz4M/G03PbZ5k1wKPF9Vj81ZNK6Z3wb8SpJtSf4pyS9240vOOzYFnuRY4A7g2qp6adR5FpLkfcC+qto+6iyH4Ah6b+tuqqozgP+h9/Z+LHX7jS+j97byZ4BjknxwtKmWbKCvnhiVJJ8EXgVu3T80z2ojz5vkLcAngT+cb/E8YyPPTO/vbyVwNvD7wJYkYQh5x6LAkxxJr7xvrao7u+G9SdZ2y9cC+0aVb45zgEuT7KL3DYznJfkS45sXev+z766qbd387fQKfVwzXwB8u6pmqur7wJ3ALzG+eWdbKONAXz0xCknWA+8DfqO69/KMb9630vuP/bHub/Ak4JEkP834Zt4N3Fk936D3zn01Q8g78gLv/ie6GdhZVTfOWrQVWN9Nr6e3b3zkquq6qjqpqibpfX3AP1TVBxnTvABV9T3gO0ne3g2dD/wr45v5OeDsJG/pXh/nAzsZ37yzLZRxK/CBJEcnORU4DfjGCPL9iCQXAR8HLq2q/521aCzzVtUTVbWmqia7v8Hd9A6C+B5jmhn4Cr3Py0jyNnoHEbzAMPKO4kOJOTvyf5ne24bHgUe7yyXAT9H7oPCZ7nrVqLPOk/1cfvgh5ljnBdYB093z/BV6b+nGNjPwR8BTwA7gr+h9Uj9WeYHb6O2j/z69IrnqYBnpvfX/Jr0POi8ek7zP0tsPu/9v78/HJe9Cmecs30X3IeY4ZF7gOT4K+FL3Wn4EOG9YeT0TU5IaNfJdKJKkxbHAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElq1P8DEK3bqTHPxe8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(1)\n",
    "# Take element [0][0] of each recorded batch tuple (first tensor of `xb`) and concatenate them\n",
    "t = concat(*learn.collect_data.data.itemgot(0,0))\n",
    "# Histogram of the sampled values; the trailing `;` suppresses the repr output\n",
    "plt.hist(t.numpy());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates()\n",
    "class PartialDL(TfmdDL):\n",
    "    \"Randomly select `partial_n` items (without replacement) from the dataset at each epoch\"\n",
    "    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        # Cap at the dataset size; a falsy `partial_n` (None or 0) disables subsampling\n",
    "        self.partial_n = min(partial_n, self.n) if partial_n else None\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.partial_n is None: return super().get_idxs()\n",
    "        # Draw a fresh subset of distinct indices each time indices are requested\n",
    "        return list(np.random.choice(self.n, self.partial_n, replace=False))\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.partial_n is None: return super().__len__()\n",
    "        # `bs=None` (this class's default) means no batching: the length is the item count.\n",
    "        # Without this guard the floor division below raises `TypeError`.\n",
    "        if self.bs is None: return self.partial_n\n",
    "        return self.partial_n//self.bs + (0 if self.drop_last or self.partial_n%self.bs==0 else 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):\n",
    "    \"Create dataloaders of type `PartialDL`, with `partial_n` applied to the training set only\"\n",
    "    # First loader receives `partial_n`; each remaining subset gets empty extra kwargs\n",
    "    per_dl_kwargs = ({'partial_n':partial_n},) + tuple([{}] * (self.n_subsets-1))\n",
    "    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=per_dl_kwargs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuses `dsets` from the weighted example; the first loader should sample 32 items per pass\n",
    "dls = dsets.partial_dataloaders(partial_n=32, bs=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 32 sampled items at batch size 16 -> exactly two batches, each holding 16 items\n",
    "assert len(dls[0])==2\n",
    "for batch in dls[0]:\n",
    "    assert len(batch[0])==16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "# Run nbdev's notebook export (prints one 'Converted ...' line per notebook)\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
